{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision \n",
    "from torchvision import transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.autograd import Variable "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self,D_in,H,D_out):\n",
    "        super(Encoder,self).__init__()\n",
    "        self.linear1 = nn.Linear(D_in,H)\n",
    "        self.linear2 = nn.Linear(H,D_out)\n",
    "    def forward(self,x):\n",
    "        x = F.relu(self.linear1(x))\n",
    "        return F.relu(self.linear2(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self,D_in,H,D_out):\n",
    "        super(Decoder,self).__init__()\n",
    "        self.linear1 = nn.Linear(D_in,H)\n",
    "        self.linear2 = nn.Linear(H,D_out)\n",
    "    def forward(self,x):\n",
    "        x = F.relu(self.linear1(x))\n",
    "        return F.relu(self.linear2(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    def __init__(self,encoder,decoder):\n",
    "        super(VAE,self).__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self._enc_mu = nn.Linear(100,8)\n",
    "        self._enc_log_sigma = nn.Linear(100,8)\n",
    "        \n",
    "    def _sample_latent(self,h_enc):\n",
    "        mu = self._enc_mu(h_enc)\n",
    "        log_sigma = self._enc_log_sigma(h_enc)\n",
    "        sigma = torch.exp(log_sigma)\n",
    "        std_z = torch.from_numpy(np.random.normal(0,1,size = sigma.size())).float()\n",
    "        \n",
    "        self.z_mean = mu\n",
    "        self.z_sigma = sigma\n",
    "        \n",
    "        return mu + sigma * Variable(std_z,requires_grad = False)\n",
    "    \n",
    "    def forward(self,state):\n",
    "        h_enc = self.encoder(state)\n",
    "        z = self._sample_latent(h_enc)\n",
    "        output = self.decoder(z)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NUM_SAMPLES: 60000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/caijie/anaconda3/lib/python3.6/site-packages/torch/tensor.py:263: UserWarning: non-inplace resize is deprecated\n",
      "  warnings.warn(\"non-inplace resize is deprecated\")\n",
      "/Users/caijie/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:32: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 0.066152\n",
      "epoch: 1, loss: 0.065962\n",
      "epoch: 2, loss: 0.064354\n",
      "epoch: 3, loss: 0.066917\n",
      "epoch: 4, loss: 0.073372\n",
      "epoch: 5, loss: 0.069688\n",
      "epoch: 6, loss: 0.061277\n",
      "epoch: 7, loss: 0.071178\n",
      "epoch: 8, loss: 0.068341\n",
      "epoch: 9, loss: 0.063634\n",
      "epoch: 10, loss: 0.075112\n",
      "epoch: 11, loss: 0.070322\n",
      "epoch: 12, loss: 0.068794\n",
      "epoch: 13, loss: 0.067258\n",
      "epoch: 14, loss: 0.064795\n",
      "epoch: 15, loss: 0.069586\n",
      "epoch: 16, loss: 0.070698\n",
      "epoch: 17, loss: 0.061841\n",
      "epoch: 18, loss: 0.072216\n",
      "epoch: 19, loss: 0.069646\n",
      "epoch: 20, loss: 0.063879\n",
      "epoch: 21, loss: 0.064958\n",
      "epoch: 22, loss: 0.065745\n",
      "epoch: 23, loss: 0.062291\n",
      "epoch: 24, loss: 0.065736\n",
      "epoch: 25, loss: 0.067382\n",
      "epoch: 26, loss: 0.068734\n",
      "epoch: 27, loss: 0.065126\n",
      "epoch: 28, loss: 0.066158\n",
      "epoch: 29, loss: 0.072401\n",
      "epoch: 30, loss: 0.065767\n",
      "epoch: 31, loss: 0.062006\n",
      "epoch: 32, loss: 0.069818\n",
      "epoch: 33, loss: 0.066680\n",
      "epoch: 34, loss: 0.064282\n",
      "epoch: 35, loss: 0.070074\n",
      "epoch: 36, loss: 0.071773\n",
      "epoch: 37, loss: 0.071648\n",
      "epoch: 38, loss: 0.068634\n",
      "epoch: 39, loss: 0.065071\n",
      "epoch: 40, loss: 0.067783\n",
      "epoch: 41, loss: 0.063087\n",
      "epoch: 42, loss: 0.069400\n",
      "epoch: 43, loss: 0.066577\n",
      "epoch: 44, loss: 0.067379\n",
      "epoch: 45, loss: 0.063353\n",
      "epoch: 46, loss: 0.064962\n",
      "epoch: 47, loss: 0.062714\n",
      "epoch: 48, loss: 0.068711\n",
      "epoch: 49, loss: 0.067809\n",
      "epoch: 50, loss: 0.069558\n",
      "epoch: 51, loss: 0.073721\n",
      "epoch: 52, loss: 0.069299\n",
      "epoch: 53, loss: 0.062651\n",
      "epoch: 54, loss: 0.072275\n",
      "epoch: 55, loss: 0.067896\n",
      "epoch: 56, loss: 0.064450\n",
      "epoch: 57, loss: 0.060585\n",
      "epoch: 58, loss: 0.065866\n",
      "epoch: 59, loss: 0.069795\n",
      "epoch: 60, loss: 0.065034\n",
      "epoch: 61, loss: 0.069902\n",
      "epoch: 62, loss: 0.063962\n",
      "epoch: 63, loss: 0.066541\n",
      "epoch: 64, loss: 0.069445\n",
      "epoch: 65, loss: 0.061591\n",
      "epoch: 66, loss: 0.060550\n",
      "epoch: 67, loss: 0.068148\n",
      "epoch: 68, loss: 0.067919\n",
      "epoch: 69, loss: 0.064506\n",
      "epoch: 70, loss: 0.064197\n",
      "epoch: 71, loss: 0.067403\n",
      "epoch: 72, loss: 0.069086\n",
      "epoch: 73, loss: 0.068950\n",
      "epoch: 74, loss: 0.070545\n",
      "epoch: 75, loss: 0.063615\n",
      "epoch: 76, loss: 0.064511\n",
      "epoch: 77, loss: 0.073136\n",
      "epoch: 78, loss: 0.068302\n",
      "epoch: 79, loss: 0.070272\n",
      "epoch: 80, loss: 0.077405\n",
      "epoch: 81, loss: 0.062408\n",
      "epoch: 82, loss: 0.061934\n",
      "epoch: 83, loss: 0.067034\n",
      "epoch: 84, loss: 0.073439\n",
      "epoch: 85, loss: 0.065788\n",
      "epoch: 86, loss: 0.064530\n",
      "epoch: 87, loss: 0.064150\n",
      "epoch: 88, loss: 0.063736\n",
      "epoch: 89, loss: 0.067958\n",
      "epoch: 90, loss: 0.068397\n",
      "epoch: 91, loss: 0.071213\n",
      "epoch: 92, loss: 0.066068\n",
      "epoch: 93, loss: 0.070116\n",
      "epoch: 94, loss: 0.073254\n",
      "epoch: 95, loss: 0.072912\n",
      "epoch: 96, loss: 0.063698\n",
      "epoch: 97, loss: 0.068238\n",
      "epoch: 98, loss: 0.058176\n",
      "epoch: 99, loss: 0.065383\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPwAAAD6CAYAAACF8ip6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADrZJREFUeJzt3U+IneUVx/HfiUnMJGYm88dxkiEm\nSjBZiG6KFFqwCShibaG4kSIVW8i2IHQpdNNFcVUQbGfVIkJBRKiCFI1usqgyWSRoQGMgCRPzx8k4\nmWgmyZg8XcwbGie559zef/POnO8HQm7mzDv3mTf55b1zz/s8j5VSBCCHNcs9AAC9Q+CBRAg8kAiB\nBxIh8EAiBB5IhMADiRB4IBECDySytttPYGbcygd033Qp5d7ok/7vK7yZbTCzd83ssJm9bmbW2vgA\ndNDJZj6plZf0z0uaKqU8KmlQ0hMtfA0Ay6CVwO+T9H71+ENJezs3HADd1ErghyVdrB7PSRpa+glm\ntt/MJs1ssp3BAeisVt60m5Y0UD0eqP78A6WUCUkTEm/aAXXSyhX+gKQnq8f7JH3UueEA6KZWAv+G\npHEzOyJpRov/AQBYAf7vl/SllKuSnunCWAB0GXfaAYkQeCARAg8kQuCBRAg8kAiBBxIh8EAiBB5I\nhMADiRB4IBECDyRC4IFECDyQCIEHEun6MtWon4GBAbd+8eJFtx4ZGxtrWDt79mxbXxvt4QoPJELg\ngUQIPJAIgQcSIfBAIgQeSITAA4nQh1+h7r///oa1zZs3u8dGffj+/n63vmnTJrfubSj83Xffucde\nunTJrc/Ozrr1ubm5hrVTp065x2bAFR5IhMADiRB4IBECDyRC4IFECDyQCIEHEqEPv0x27tzp1kdG\nRtz6tm3bWqpJ0vj4uFv35rNLcR9+YWGhYW1mZsY9Npov/9VXX7n1qamphrXo/oTPPvvMra8GLV3h\nzewpM5sys4PVr92dHhiAzmvnCv9aKeVPHRsJgK5r52f4Z83sEzN7y7x7KQHURquBPy7p5VLKY5K2\nSnr81qKZ7TezSTObbHeAADqn1Zf0M5I+qB6fkDR6a7GUMiFpQpLMrLQ6OACd1eoV/iVJz5nZGkkP\nS/q0c0MC0C2tBv5VSS9K+ljS26WUo50bEoBuaeklfSnljKSfdXYo9eP1m++55x732MHBQbce9bq9\n+e6S9MADDzSsPfjggy0fK0n33XefW+/r63PrV69ebVi7cOGCe+zp06fd+pdffunW169f79Y93ril\neL3+r7/+uuXn7hXutAMSIfBAIgQeSITAA4kQeCARAg8kwvRYh7ek8vbt291jR0dH3fqOHTvc+q5d\nu9z6nj17GtaiqbfR9Nlomeo1a/zrxLVr1xrWNmzY4B67dq3/T/L77793697fWdRWi+pRS3Al4AoP\nJELggUQIPJAIgQcSIfBAIgQeSITAA4ms6j581PON6t7WxFGvOpr+GvXxoymu3vFbtmxxj71x44Zb\nj6awXr9+3a2389zr1q1z6+1shR0dG037je6dOHnypFuvA67wQCIEHkiEwAOJEHggEQIPJELggUQI\nPJDIqu7DX7lyxa1HS00/9NBDDWvRlslRL3x4eNitDw0NuXVv3ng0r9vbUlmSvv32W7ce8frZGzdu\ndI+N5to//fTTbv2VV15x655oi8RSVv4mSlzhgUQIPJAIgQcSIfBAIgQeSITAA4kQeCCRVd2Hj0xP\nT7t1r1cebUsc9ZujPn7UE56bm2tYm5mZcY+Nvu9o2+S7777brXv3IIyMjLjHRvPh33zzTbd+9OjR\nhrX5+Xn3WG89fSleE38laOoKb2brzOyd6vEGM3vXzA6b2esW/csEUBth4M2sT9IhSU9UH3pe0lQp\n5VFJg7d8HEDNhYEvpcyXUh6RdPN+zH2S3q8efyhpb5fGBqDDWnnTbljSzZu15yTddtO3me03s0kz\nm2xncAA6q5U37aYl3VwpcKD68w+UUiYkTUiSma38GQfAKtHKFf6ApCerx/skfdS54QDoplYC/4ak\ncTM7ImlGi/8BAFgBmn5JX0rZVf1+VdIzXRtRD0Vz1j3RvO2oVx318RcWFty6tw/6uXPn3GO99fal\n+HuL1vP3vvfovERzzqM1Dry5/JcuXWrra0dr6q8E3GkHJELggUQIPJAIgQcSIfBAIgQeSCT19Nio\nPbVt27aGtaitFm09HB0f8aawRi29aGyDg4NuPdoKe3R0tGEtWho82qo6mrrrteUuX77sHhtNjz17\n9qxbXwm4wgOJEHggEQIPJELggUQIPJAIgQcSIfBAIqn78BFvS+Z2pohGX7sZ3mLBmzdvdo+Ntqoe\nHx93616fXfJ77dEU02+++catR71yr74atntuF1d4IBECDyRC4IFECDyQCIEHEiHwQCIEHkiEPrzD\nW6452tY42lQ32no46hl7zx/1yXfu3OnWvXUAJGlo6LbdxX7A67V721xL0vXr1916NNff+zuL7n2I\n6lu3bnXr0TLY3lz9XuEKDyRC4IFECDyQCIEHEiHwQCIEHkiEwAOJ0Id3tLN2fLR+erRGetTH98bW\n39/vHhvNl4/m+kdz2r2trNvtVXezD3/XXXe59TNnzrj1laCpK7yZrTOzd6rHT5nZlJkdrH7t7u4Q\nAXRKeIU3sz5JH0t66JYPv1ZK+VPXRgWgK8IrfCllvpTyiKSpWz78rJl9YmZvWfTaE0BttPKm3XFJ\nL5dSHpO0VdLjSz/BzPab2aSZTbY7QACd08qbdjOSPqgen5B020yNUsqEpAlJMjNWDgRqopUr/EuS\nnjOzNZIelvRpZ4cEoFtaCfyrkl7U4ht5b5dSjnZ2SAC6pemX9KWUXdXvZyT9rFsD6qVo/XWvpxvN\n275y5Ypbv3jxoluP1l/ftGlTw1rUR4/Wfo/eh4362d7YZ2dn3WO9Hr4Un3dv7NH3FX3t1YA77YBE\nCDyQCIEHEiHwQCIEHkiEwAOJpJ4ee/r06ZbrO3bscI+dn59361F7Kmrrea2vqL3ktRuleApqtBW2\nt8R2u9Nfo9aaV3/vvffcY/fs2ePWR0ZG3Pr09LRbrwOu8EAiBB5IhMADiRB4IBECDyRC4IFECDyQ\nyKruw3tTSCVpbGzMrXvTZ6PtnKM+fLtLJnt9+mgr62gK6saNG916u2P3RH32aIlsbxvuvXv3usee\nOnXKra+EPnuEKzyQCIEHEiHwQCIEHkiEwAOJEHggEQIPJLKq+/BRvznqJ3tbMke97qje19fn1qMt\nnb3jBwYG2vra0f0LUd0T3Z/g9dGleJvtubm5hrVoq+poi+/VgCs8kAiBBxIh8EAiBB5IhMADiRB4\nIBECDyQS9uFtcYLy3yXtlnRe0q8l/VPSdklHJP2mRJPDVyhvbrbXo5fiXni0xvnQ0JBb37JlS0u1\nZr52f3+/W4++d2/t+Xa30b5w4YJb97bCjvrw0Rbdq0EzV/ifSFpbSvmxpH5Jv5U0VUp5VNKgpCe6\nOD4AHdRM4M9J+kv1+JqkP0p6v/rzh5L8ZUQA1Eb4kr6UckySzOxXktZLOiTp5uuuOS2+1AewAjT1\npp2Z/VLS7yX9Qos/x9/8AXVA0m0LfZnZfjObNLPJTg0UQPvCwJvZmKQ/SPp5KeWSpAOSnqzK+yR9\ntPSYUspEKeVHpZQfdXKwANrTzBX+BUlbJf3bzA5KWidp3MyOSJrR4n8AAFaAZn6G/7OkPy/58N+6\nM5ze+vzzz9261zqLtmSOtlSOWl9R286rR223qGUYTRuOph17rbNoKeiTJ0+69WiL73PnzjWseVNn\nJen8+fNufTXgxhsgEQIPJELggUQIPJAIgQcSIfBAIgQeSGRVL1MdiXrdXj856qO3M71VkkZHR926\ntwx2tF1ztByzN8VUks6cOePWjx8/3rAW3ftw7Ngxtz41NeXWDx8+7Naz4woPJELggUQIPJAIgQcS\nIfBAIgQeSITAA4mk7sNPT9+2OlfT9WjOeCTqhS8sLLh1bynoaL579LVnZ2fdetQL9+a0R8dG9WiZ\n6sHBwYa16P6CDLjCA4kQeCARAg8kQuCBRAg8kAiBBxIh8EAi1u2dns1sVW4lHdm+fbtb9/rFkjQ8\nPOzW77333oa1aE38GzduuPXLly+79ahP763/Hm3Z/MUXX7j1aB2CaO35VexQMzs9cYUHEiHwQCIE\nHkiEwAOJEHggEQIPJELggUTCPryZmaS/S9ot6bykCUl/lXSi+pTflVIaLjaetQ/frrGxMbe+fv36\nhrXFv7LGoj589G8imrPuoY/eNR3rw/9E0tpSyo8l9Uu6Iem1UspPq1/+zgIAaqOZwJ+T9Jfq8bXq\n92fN7BMze8uiywmA2ggDX0o5Vkr5xMx+JWm9pOOSXi6lPCZpq6THlx5jZvvNbNLMJjs+YgAta2ph\nNjP7paTfS/qFFkN/oiqdkHTbJmillAkt/qzPz/BAjYRXeDMbk/QHST8vpVyS9JKk58xsjaSHJX3a\n3SEC6JRmfoZ/QYsv3f9tZgclXZb0oqSPJb1dSjnaxfEB6CCmxwKrA9NjAfwQgQcSIfBAIgQeSITA\nA4kQeCARAg8kQuCBRAg8kAiBBxIh8EAiBB5IhMADiRB4IBECDyTS1BJXbZqWdPKWP49UH6sjxtaa\nuo6truOSOj+2Hc18UtcXwLjtCc0mm5movxwYW2vqOra6jktavrHxkh5IhMADiSxH4CeW4Tmbxdha\nU9ex1XVc0jKNrec/wwNYPrykBxLpWeDNbIOZvWtmh83s9TrtSWdmT5nZlJkdrH7tXu4xSZKZrTOz\nd6rHtTp/S8ZWi/Nni/5hZv8xs3+Z2T11OWd3GNszy3HOenmFf17SVCnlUUmDkp7o4XM3o1Y74ppZ\nn6RD+t95qs35u8PYpHqcv6U7Hf9WNTlndxjbsuzC3MvA75P0fvX4Q0l7e/jczajVjrillPlSyiOS\nbm7GXpvzd4exSfU4f0t3Ov6janLOVJNdmHsZ+GFJF6vHc5KGevjckXBH3Brg/AXusNPxIdXknLWy\nC3M39DLw05IGqscDqtctjzOSPqgen9AddsStAc5fE5bsdHxeNTpnS8Y2rWU4Z70M/AFJT1aP90n6\nqIfPHVkJO+Jy/gJ32Om4NuesLrsw9zLwb0gaN7MjWrwiHOjhc0deVf13xOX8xZbudLxO9TlntdiF\nmRtvgES48QZIhMADiRB4IBECDyRC4IFECDyQCIEHEvkvPDpWmSLJhDgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def latent_loss(mean,std):\n",
    "    mean_sq = mean*mean\n",
    "    std_sq = std*std\n",
    "    return 0.5 * torch.mean(mean_sq + std_sq - torch.log(std_sq) - 1)\n",
    "    \n",
    "if __name__=='__main__':\n",
    "    batch_size = 32\n",
    "    input_size = 28*28\n",
    "    transform = transforms.Compose([transforms.ToTensor()])\n",
    "    mnist = torchvision.datasets.MNIST('./',download=True,transform=transform)\n",
    "    dataloader = torch.utils.data.DataLoader(mnist,batch_size = batch_size,shuffle = True,num_workers = 2)\n",
    "    print('NUM_SAMPLES:',len(mnist))\n",
    "    \n",
    "    encoder = Encoder(input_size,100,100)\n",
    "    decoder = Decoder(8,100,input_size)\n",
    "    vae = VAE(encoder,decoder)\n",
    "    \n",
    "    criterion = nn.MSELoss()\n",
    "    optimizer = optim.Adam(vae.parameters(),lr = 0.001)\n",
    "    l = None\n",
    "    \n",
    "    for epoch in range(100):\n",
    "        for i,data in enumerate(dataloader,0):\n",
    "            inputs,classes = data\n",
    "            inputs,classes = Variable(inputs.resize(batch_size,input_size)),Variable(classes)\n",
    "            dec = vae(inputs)\n",
    "            optimizer.zero_grad()\n",
    "            ll = latent_loss(vae.z_mean,vae.z_sigma)\n",
    "            loss = criterion(dec,inputs)+ll\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            l = loss.data[0]\n",
    "        print('epoch: %d, loss: %f' % (epoch,l))\n",
    "        \n",
    "    plt.imshow(vae(inputs).data[0].numpy().reshape(28, 28), cmap='gray')\n",
    "    plt.show(block=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
